import os
# Turn off XLA/JAX up-front GPU memory preallocation. The variable is set
# before `jax` is imported below so it is already in place when the JAX
# backend starts up.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
from jax import config

from typing import Tuple

import datetime
import gym
import numpy as np
import tqdm
import time
import absl
import sys
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter
from dataclasses import dataclass

import d4rl

import wrappers
from dataset_utils import D4RLDataset, split_into_trajectories, merge_datasets
from evaluation import evaluate
from learner import Learner, PAGAR_Learner
import warnings
from logging_utils.logx import EpochLogger
import math
import pickle
import json



# Handle to the absl flag registry; values are available once app.run() has
# parsed the command line.
FLAGS = flags.FLAGS

# --- Experiment identity and output locations -------------------------------
# main() parses names containing '+' as
# '<task>-<other>+<expert>-<version>-<num_exps>-<ratio_exps>'.
flags.DEFINE_string('env_name', 'halfcheetah-expert-v2', 'Environment name.')
flags.DEFINE_string('f', 'chi-square', 'f-divergence to use.[chi-square, total-variation, reverse-KL(XQL)]')
flags.DEFINE_string('save_dir', './tmp/', 'Tensorboard logging dir.')
flags.DEFINE_string('exp_name', 'dump', 'Epoch logging dir.')
# main() runs several seeds derived from this value (seed = k * FLAGS.seed).
flags.DEFINE_integer('seed', 42, 'Random seed.')

# --- Evaluation / logging cadence -------------------------------------------
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
# NOTE(review): log_interval is defined but not read anywhere in this file.
flags.DEFINE_integer('log_interval', 5000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.')

# --- Optimisation settings passed directly to the learners ------------------
flags.DEFINE_integer('batch_size', 1024, 'Mini batch size.')
flags.DEFINE_float('temp', 1.0, 'Loss temperature')
flags.DEFINE_boolean('double', True, 'Use double q-learning')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
# NOTE(review): tqdm is defined but the training loop in main() does not use it.
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')

# --- Options bundled into ConfigArgs and forwarded as `args` ----------------
flags.DEFINE_integer('sample_random_times', 0, 'Number of random actions to add to smooth dataset')
flags.DEFINE_boolean('grad_pen', False, 'Add a gradient penalty to critic network')
flags.DEFINE_float('lambda_gp', 1, 'Gradient penalty coefficient')
flags.DEFINE_float('max_clip', 7., 'Loss clip value')
flags.DEFINE_integer('num_v_updates', 1, 'Number of value updates per iter')
flags.DEFINE_boolean('log_loss', False, 'Use log gumbel loss')
flags.DEFINE_boolean('noise', False, 'Add noise to actions')
flags.DEFINE_float('noise_std', 0.1, 'Noise std for actions')

# Hyperparameter file; main() expands it with dict(FLAGS.config) and passes the
# entries as keyword arguments to both learners. lock_config=False leaves the
# loaded config unlocked.
config_flags.DEFINE_config_file(
    'config',
    'default.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)



@dataclass(frozen=True)
class ConfigArgs:
    """Immutable bundle of command-line options handed to the learners.

    Each field mirrors the command-line flag of the same name; ``main``
    builds one instance from ``FLAGS`` and passes it as ``args=`` to both
    learners. The field order is part of the interface (positional
    construction) and must not change.
    """

    # Name of the f-divergence (flag --f).
    f: str
    # Number of random actions added to smooth the dataset.
    sample_random_times: int
    # Whether to add a gradient penalty to the critic network.
    grad_pen: bool
    # Whether to add noise to actions, and the std of that noise.
    noise: bool
    noise_std: float
    # Gradient penalty coefficient. The flag is declared with DEFINE_float,
    # so the value is a float (it was previously mis-annotated as int).
    lambda_gp: float
    # Loss clip value.
    max_clip: float
    # Number of value updates per iteration.
    num_v_updates: int
    # Whether to use the log gumbel loss.
    log_loss: bool

def get_normalized_returns(env_name, dataset):
    """Return the D4RL-normalized return of every finished episode.

    Walks the first ``dataset.size`` transitions, accumulating
    ``dataset.rewards``. Whenever ``dataset.dones_float`` equals 1 the
    accumulated return is converted with ``d4rl.get_normalized_score`` for
    ``env_name``, recorded, and the accumulator is reset. Rewards after the
    last terminal flag (an unfinished episode) are not reported.

    Args:
        env_name: environment id passed to ``d4rl.get_normalized_score``.
        dataset: object exposing ``size``, ``rewards`` and ``dones_float``.

    Returns:
        List of normalized episode returns, in dataset order.
    """
    normalized_scores = []
    running_return = 0
    for step in range(dataset.size):
        running_return += dataset.rewards[step]
        if dataset.dones_float[step] != 1.:
            # Episode still in progress: keep accumulating.
            continue
        normalized_scores.append(
            d4rl.get_normalized_score(env_name, running_return))
        running_return = 0

    return normalized_scores

def normalize(dataset):
    """Rescale ``dataset.rewards`` in place by the spread of episode returns.

    The transitions are grouped with ``split_into_trajectories``; each
    trajectory's return is the sum of its rewards. Rewards are then divided
    by (largest return - smallest return) and multiplied by 1000.

    Args:
        dataset: object exposing ``observations``, ``actions``, ``rewards``,
            ``masks``, ``dones_float``, ``next_observations`` and
            ``is_experts``; its ``rewards`` are modified in place.
    """
    trajectories = split_into_trajectories(
        dataset.observations,
        dataset.actions,
        dataset.rewards,
        dataset.masks,
        dataset.dones_float,
        dataset.next_observations,
        dataset.is_experts,
    )

    def episode_return(trajectory):
        # Each transition is a 7-tuple; the reward is its third entry.
        total = 0
        for _obs, _act, reward, _mask, _done, _next_obs, _is_expert in trajectory:
            total += reward
        return total

    # Order trajectories from lowest to highest return to read off both ends.
    ordered = sorted(trajectories, key=episode_return)
    return_span = episode_return(ordered[-1]) - episode_return(ordered[0])

    dataset.rewards /= return_span
    dataset.rewards *= 1000.0

def make_env(env_name: str, seed: int) -> gym.Env:
    """Create a seeded gym environment wrapped with the project wrappers.

    The environment from ``gym.make(env_name)`` is wrapped first in
    ``wrappers.EpisodeMonitor`` and then in ``wrappers.SinglePrecision``.
    The wrapped environment, its action space and its observation space are
    all seeded with ``seed``.

    Args:
        env_name: gym environment id.
        seed: seed applied to the environment and both of its spaces.

    Returns:
        The wrapped, seeded environment.
    """
    monitored = wrappers.EpisodeMonitor(gym.make(env_name))
    env = wrappers.SinglePrecision(monitored)

    # Seed the environment itself before its spaces, as the spaces are
    # sampled independently of the environment's own RNG.
    env.seed(seed)
    for space in (env.action_space, env.observation_space):
        space.seed(seed)

    return env
    
def make_env_and_dataset(env_name: str,
                         seed: int) -> Tuple[gym.Env, D4RLDataset]:
    """Build the environment for ``env_name`` and its preprocessed dataset.

    The environment comes from ``make_env`` and the dataset from
    ``D4RLDataset(env)``. Rewards are then preprocessed according to the
    task named in ``env_name``:

    * names containing ``'antmaze'``: rewards are shifted by -1;
    * names containing ``'halfcheetah'``, ``'walker2d'`` or ``'hopper'``:
      rewards are rescaled in place by ``normalize``;
    * anything else: rewards are left untouched.

    The task is determined from the ``env_name`` argument (previously the
    global ``FLAGS.env_name`` was consulted, so the preprocessing could
    disagree with the environment actually being built).

    Args:
        env_name: gym / D4RL environment id.
        seed: seed forwarded to ``make_env``.

    Returns:
        ``(env, dataset)``.
    """
    env = make_env(env_name, seed)
    dataset = D4RLDataset(env)

    if 'antmaze' in env_name:
        # An alternative scaling, (r - 0.5) * 4, is used in
        # https://github.com/aviralkumar2907/CQL/blob/master/d4rl/examples/cql_antmaze_new.py#L22
        dataset.rewards -= 1.0
    elif any(task in env_name
             for task in ('halfcheetah', 'walker2d', 'hopper')):
        normalize(dataset)

    return env, dataset


def _parse_mixed_env_name(name):
    """Split a mixed-dataset name into its components.

    Expected dash-separated layout:
    ``<task>-<other>+<expert>-<version>-<num_exps>-<ratio_exps>``.

    Args:
        name: value of ``--env_name``.

    Returns:
        Tuple ``(task, other, expert, version, num_exps, ratio_exps)`` of
        strings.

    Raises:
        ValueError: if the name has no ``<other>+<expert>`` component.
    """
    parts = name.split('-')
    # NOTE(review): qualities made of several dash-separated tokens are
    # re-joined with '_' (unchanged from the previous behaviour) - confirm
    # this matches the environment ids registered by d4rl.
    quality_pair = '_'.join(parts[1:-3])
    if '+' not in quality_pair:
        raise ValueError(
            "env_name must look like "
            "'<task>-<other>+<expert>-<version>-<num_exps>-<ratio_exps>', "
            f"got {name!r}")
    other, expert = quality_pair.split('+')[:2]
    version, num_exps, ratio_exps = parts[-3:]
    return parts[0], other, expert, version, num_exps, ratio_exps


def _load_or_build_datasets(seed, spec):
    """Return ``(env, expert_dataset, suboptimal_dataset, mix_dataset)``.

    If ``<FLAGS.env_name>_dataset.pt`` exists next to this file, the three
    datasets are unpickled from it and only the environment is created.
    Otherwise they are built from the two D4RL datasets named by ``spec``
    and the result is pickled to that path for later runs.

    Args:
        seed: seed forwarded to ``make_env`` / ``make_env_and_dataset``.
        spec: tuple returned by ``_parse_mixed_env_name``.
    """
    task, other, expert, version, num_exps, ratio_exps = spec
    suboptimal_env_id = '-'.join([task, other, version])
    expert_env_id = '-'.join([task, expert, version])
    dataset_path = os.path.join(os.path.dirname(__file__),
                                FLAGS.env_name + '_dataset.pt')

    if os.path.exists(dataset_path):
        env = make_env(suboptimal_env_id, seed)
        # SECURITY NOTE: pickle.load can execute arbitrary code. This is only
        # acceptable because the cache file is written by this script itself;
        # never point it at a file from an untrusted source.
        with open(dataset_path, 'rb') as f:
            expert_dataset, suboptimal_dataset, mix_dataset = pickle.load(f)
        return env, expert_dataset, suboptimal_dataset, mix_dataset

    print(f"load env {suboptimal_env_id}")
    env, suboptimal_dataset = make_env_and_dataset(suboptimal_env_id, seed)
    _, expert_dataset = make_env_and_dataset(expert_env_id, seed)

    # 'inf' keeps every expert trajectory; otherwise keep `num_exps` of them.
    expert_dataset = expert_dataset.sample_trajectories(
        num=int(num_exps) if num_exps != 'inf' else None,
        return_dataset=True)

    mix_dataset = merge_datasets(suboptimal_dataset, expert_dataset)
    ratio = float(ratio_exps)
    # Merge further copies of the expert data into the mixture; the count is
    # derived from the requested ratio (one copy is already merged above,
    # hence the "- 1").
    repeat = math.ceil(
        (suboptimal_dataset.size * ratio / (1 - ratio)) / expert_dataset.size - 1)
    for _ in range(repeat):
        mix_dataset.merge(expert_dataset)
    # Must manually set rewards to 0 for recoil. critic.py updates q network
    # with bellman equation
    mix_dataset.rewards *= 0.

    with open(dataset_path, 'wb') as f:
        pickle.dump((expert_dataset, suboptimal_dataset, mix_dataset), f)

    return env, expert_dataset, suboptimal_dataset, mix_dataset


def main(_):
    """Train and evaluate ``Learner`` and ``PAGAR_Learner`` over four seeds.

    For k = 1..4 the run uses ``seed = k * FLAGS.seed``. Each run loads (or
    builds and caches) the expert / suboptimal / mixed datasets, trains both
    agents on identical batches for ``FLAGS.max_steps`` steps, evaluates both
    every ``FLAGS.eval_interval`` steps, logs through ``EpochLogger`` and
    finally pickles the per-run evaluation curves into the log folder. After
    all runs, the mean and std of the last evaluation return of each agent
    are printed and the process exits with status 0.

    The process also exits (status 0, as before) when a loss becomes NaN or
    when an existing progress file shows more than 900000 iterations.

    Raises:
        ValueError: if ``FLAGS.env_name`` is not a mixed-dataset name (the
            script previously failed later with a ``NameError``).
    """
    import pandas as pd

    # Validate the environment name before any directory, logger or dataset
    # is created.
    spec = _parse_mixed_env_name(FLAGS.env_name)
    task, other, _expert, version, _num_exps, _ratio_exps = spec
    suboptimal_env_id = '-'.join([task, other, version])

    # Learner configuration is identical for every seed.
    kwargs = dict(FLAGS.config)
    args = ConfigArgs(f=FLAGS.f,
                      sample_random_times=FLAGS.sample_random_times,
                      grad_pen=FLAGS.grad_pen,
                      lambda_gp=FLAGS.lambda_gp,
                      noise=FLAGS.noise,
                      max_clip=FLAGS.max_clip,
                      num_v_updates=FLAGS.num_v_updates,
                      log_loss=FLAGS.log_loss,
                      noise_std=FLAGS.noise_std)

    eval_return_lst = []
    pagar_eval_return_lst = []
    for run_idx in range(1, 5):
        seed = run_idx * FLAGS.seed

        ts_str = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d_%H-%M-%S")
        save_dir = os.path.join(FLAGS.save_dir, ts_str)
        exp_id = f"results/offline_rl/{FLAGS.env_name}/" + FLAGS.exp_name
        log_folder = exp_id + '/' + FLAGS.exp_name + '_s' + str(seed)
        logger_kwargs = {'output_dir': log_folder, 'exp_name': FLAGS.exp_name}

        # NOTE(review): the file name embeds the current timestamp, so this
        # check only matches a file created within the same second - confirm
        # whether a timestamp-free name was intended.
        progress_file = f'/progress_{ts_str}'
        progress_txt = log_folder + progress_file + '.txt'
        if os.path.isfile(progress_txt):
            try:
                df = pd.read_csv(progress_txt, sep=r'\s+')
                already_trained = df['Iterations'].to_numpy()[-1] > 900000
            except Exception:
                # Best effort: an unreadable progress file must not stop
                # training.
                print("Cannot read progress.txt")
                already_trained = False
            # The exit is issued outside the try block: SystemExit raised
            # inside it used to be swallowed by a bare `except:`.
            if already_trained:
                print("Exiting because already trained")
                sys.exit()

        e_logger = EpochLogger(**logger_kwargs)
        os.makedirs(save_dir, exist_ok=True)

        env, expert_dataset, suboptimal_dataset, mix_dataset = \
            _load_or_build_datasets(seed, spec)

        print(expert_dataset.size, suboptimal_dataset.size, suboptimal_dataset.dones_float.sum())

        # NOTE(review): only needed by the (currently disabled) expert
        # summary print; kept so that summary can be re-enabled.
        expert_norm_rets = get_normalized_returns(suboptimal_env_id, expert_dataset)

        agent = Learner(seed,
                        env.observation_space.sample()[np.newaxis],
                        env.action_space.sample()[np.newaxis],
                        max_steps=FLAGS.max_steps,
                        loss_temp=FLAGS.temp,
                        double_q=FLAGS.double,
                        vanilla=False,
                        args=args,
                        **kwargs)

        pagar_agent = PAGAR_Learner(seed,
                                    env.observation_space.sample()[np.newaxis],
                                    env.action_space.sample()[np.newaxis],
                                    max_steps=FLAGS.max_steps,
                                    loss_temp=FLAGS.temp,
                                    double_q=FLAGS.double,
                                    vanilla=False,
                                    args=args,
                                    **kwargs)

        eval_returns = []
        pagar_eval_returns = []
        for step in range(1, FLAGS.max_steps + 1):
            suboptimal_batch = suboptimal_dataset.sample(FLAGS.batch_size)
            expert_batch = expert_dataset.sample(FLAGS.batch_size)
            mix_batch = mix_dataset.sample(FLAGS.batch_size)

            # Both agents are updated on exactly the same batches.
            update_info = agent.update(expert_batch, suboptimal_batch, mix_batch)
            pagar_update_info = pagar_agent.update(expert_batch, suboptimal_batch, mix_batch)

            # Abort as soon as any monitored loss diverges to NaN.
            if np.isnan(update_info['actor_loss']):
                print(update_info)
                sys.exit(0)

            if (np.isnan(pagar_update_info['protagonist_actor_loss'])
                    or np.isnan(pagar_update_info['antagonist_actor_loss'])
                    or np.isnan(pagar_update_info['pagar_reward_loss'])):
                print(pagar_update_info)
                sys.exit(0)

            if step % FLAGS.eval_interval != 0:
                continue

            eval_stats = evaluate(agent, env, FLAGS.eval_episodes)
            e_logger.log_tabular('Iterations', step)
            e_logger.log_tabular('AverageNormalizedReturn', eval_stats['return'])
            for k, v in update_info.items():
                if 'loss' in k:
                    e_logger.log_tabular(f'{k}', f'{v}')
            eval_returns.append(eval_stats['return'])

            pagar_eval_stats = evaluate(pagar_agent, env, FLAGS.eval_episodes)
            e_logger.log_tabular('pagar_Iterations', step)
            e_logger.log_tabular('pagar_AverageNormalizedReturn', pagar_eval_stats['return'])
            for k, v in pagar_update_info.items():
                if 'loss' in k:
                    # Prefix PAGAR losses so they cannot collide with the
                    # columns logged for the plain learner.
                    if 'pagar' not in k:
                        k = 'pagar_' + k
                    e_logger.log_tabular(f'{k}', f'{v}')
            # Both agents share one row per evaluation.
            e_logger.dump_tabular()
            pagar_eval_returns.append(pagar_eval_stats['return'])

        # NOTE(review): the 'pgar_recoil' key looks like a typo for 'pagar',
        # but it is part of the on-disk format and therefore left unchanged.
        with open(log_folder + progress_file + '.pt', 'wb') as fp:
            pickle.dump(
                {
                    'seed': seed,
                    'recoil': eval_returns,
                    'pgar_recoil': pagar_eval_returns
                }, fp)

        eval_return_lst.append(eval_returns[-1])
        pagar_eval_return_lst.append(pagar_eval_returns[-1])

    print('avg_recoil', np.mean(eval_return_lst), 'std_recoil', np.std(eval_return_lst))
    print('avg_pagar', np.mean(pagar_eval_return_lst), 'std_pagar', np.std(pagar_eval_return_lst))
    sys.exit(0)


if __name__ == '__main__':
    app.run(main)
